
Conversation

@waridrox (Contributor) commented May 26, 2025

Closes #1316. Waiting to iterate until the merge conflicts that will arise from #1669, #1701, #1709, etc. are resolved.

Added a CalibrationCurveDisplay class that visualizes classifier probability calibration. The implementation follows scikit-learn's API patterns (ref) while extending them to support both binary and multiclass classification.
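
For reference, this is the existing scikit-learn pattern being adapted (plain scikit-learn code, not the new skore class; the skore entry point and exact constructor are still under discussion below):

# scikit-learn's existing CalibrationDisplay, shown only as the API
# reference this PR adapts from (binary classification).
from sklearn.calibration import CalibrationDisplay
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1_000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# Build the display from a fitted estimator; `n_bins` and `strategy`
# ("uniform" or "quantile") mirror the options exposed by this PR.
display = CalibrationDisplay.from_estimator(
    clf, X_test, y_test, n_bins=10, strategy="uniform"
)
display.plot()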

Summary:

  • Implemented the CalibrationCurveDisplay class, adapting its usage from scikit-learn.
  • Added support for both binary and multiclass classification.
  • Implemented different binning strategies ("uniform" and "quantile").
  • Introduced DisplayClassProtocol for type safety across display classes; it avoids Union annotations where several display classes share method names (see the sketch below).
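
A minimal sketch of what such a protocol could look like (the method signature here is an assumption for illustration, not the exact definition in this PR):

# Illustrative only: a structural type that every display class satisfies,
# so shared helpers can be annotated with one protocol instead of a Union.
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

from matplotlib.axes import Axes


@runtime_checkable
class DisplayClassProtocol(Protocol):
    def plot(self, ax: Axes | None = None, **kwargs: Any) -> None:
        """Every display class exposes a plot method."""
        ...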

Testing

Added tests to cover binary classification scenarios.
Added tests to cover multiclass classification scenarios.

TODO:


Calibration plots from the example file:

  1. Single Estimator Calibration Curve
# Calibration curve - Logistic Regression
# `display`: a CalibrationCurveDisplay instance created earlier in the example file
import matplotlib.pyplot as plt

# Create the figure and axes: calibration curve on top, histogram below
fig, (ax, hist_ax) = plt.subplots(
    nrows=2, figsize=(8, 8), height_ratios=[2, 1], sharex=True
)

# Plot the calibration curve
display.plot(
    ax=ax,
    hist_ax=hist_ax,
    line_kwargs={"color": "darkblue", "marker": "o", "linewidth": 2},
    ref_line_kwargs={
        "color": "black",
        "linestyle": "--",
        "alpha": 0.8,
        "linewidth": 1.5,
    },
    hist_kwargs={"color": "skyblue", "alpha": 0.5},
)
[screenshot: single-estimator calibration curve with histogram]
  2. Multiple Estimators Comparison
# Calibration curves for different classifiers
# Compares: Logistic Regression, Naive Bayes, Random Forest, and SVC
# Create the figure and axes
fig, (ax, hist_ax) = plt.subplots(
    nrows=2, figsize=(10, 10), height_ratios=[2, 1], sharex=True
)

# Plot with custom styling
display.plot(
    ax=ax,
    hist_ax=hist_ax,
    line_kwargs={"linewidth": 2, "alpha": 0.8, "marker": "o", "markersize": 5},
    ref_line_kwargs={
        "color": "black",
        "linestyle": "--",
        "alpha": 0.7,
        "linewidth": 1.5,
    },
    hist_kwargs={"alpha": 0.3},
)
[screenshot: calibration curves for the four classifiers]
  3. Calibrated vs Uncalibrated Models
# Calibration curves for uncalibrated vs. calibrated models
# Compares: Random Forest with different calibration methods (uncalibrated, sigmoid, isotonic)
# Create the figure and axes
fig, (ax, hist_ax) = plt.subplots(
    nrows=2, figsize=(10, 10), height_ratios=[2, 1], sharex=True
)

# Plot with custom styling
display.plot(
    ax=ax,
    hist_ax=hist_ax,
    line_kwargs={"linewidth": 2, "alpha": 0.8, "marker": "o", "markersize": 6},
    ref_line_kwargs={
        "color": "black",
        "linestyle": "--",
        "alpha": 0.7,
        "linewidth": 1.5,
    },
    hist_kwargs={"alpha": 0.3},
)
[screenshot: uncalibrated vs sigmoid vs isotonic calibration curves]
  4. Different Binning Strategies (a sketch of how the two strategies differ follows after these examples)
# Calibration plots for different strategies and bin counts
# Compares: binning strategies (uniform vs quantile) with different bin counts
from matplotlib.gridspec import GridSpec

# Setup the figure with a GridSpec layout
fig = plt.figure(figsize=(12, 10))
gs = GridSpec(3, 2, height_ratios=[2, 1, 1])

# Create a single plot for all calibration curves
ax_calibration_curve = fig.add_subplot(gs[0, :])
ax_calibration_curve.set_title(
    "Calibration plots for different strategies and bin counts"
)
[screenshot: calibration curves for different strategies and bin counts]
  5. Sklearn-Style Comparison (with GridSpec)
# Calibration plots
# Compares: Multiple classifiers in scikit-learn style layout
# Setup the figure with GridSpec
fig = plt.figure(figsize=(12, 10))
gs = GridSpec(4, 2, height_ratios=[2, 2, 1, 1])

# Create a single plot for all calibration curves
ax_calibration_curve = fig.add_subplot(gs[:2, :])
ax_calibration_curve.set_title("Calibration plots", fontsize=16)
[screenshot: scikit-learn-style comparison of multiple classifiers]
  6. Multi-class classification
# Calibration curves for multiclass classification
# Shows: Calibration curves for a three-class classification problem
# Setup the figure with GridSpec
fig = plt.figure(figsize=(15, 10))
gs = GridSpec(3, 3, height_ratios=[2, 1, 1])

# Create a single plot for all calibration curves
ax_calibration_curve = fig.add_subplot(gs[0, :])
ax_calibration_curve.set_title(
    "Calibration curves for multiclass classification", fontsize=16
)
[screenshot: calibration curves for a three-class classification problem]
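
To make the binning strategies in example 4 concrete, here is how "uniform" and "quantile" differ when computing the underlying curve with plain scikit-learn (synthetic data; calibration_curve is used directly rather than the new display class):

# "uniform" splits [0, 1] into equal-width bins; "quantile" puts the same
# number of predictions in each bin, which helps when probabilities are skewed.
import numpy as np
from sklearn.calibration import calibration_curve

rng = np.random.default_rng(0)
y_prob = rng.beta(2, 5, size=1_000)   # skewed predicted probabilities
y_true = rng.binomial(1, y_prob)      # labels drawn from those probabilities

for strategy in ("uniform", "quantile"):
    prob_true, prob_pred = calibration_curve(
        y_true, y_prob, n_bins=10, strategy=strategy
    )
    print(strategy, prob_pred.round(2))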

I would need to identify which in-progress PRs might affect this before iterating on anything critical, hence keeping this as a draft for now.
CC: @MarieSacksick @auguste-probabl

@glemaitre (Member) commented May 26, 2025

In terms of API, I think we will want to expose the calibration display in an accessor other than the metrics one. It is more of a diagnostic tool, so we need to think about where to add it.

I'm also thinking that we should potentially enlarge the scope of the tool and call it a reliability diagram, which would cover both the classification and regression cases.

Also, I have not gone through the code yet, but we need to think twice before implementing the multiclass case. Using a one-vs-rest approach might not be ideal, since some classifiers do not produce their probabilities with that strategy. I was looking at https://arxiv.org/pdf/2112.10327 and https://arxiv.org/abs/2210.16315; there are several trade-offs and we should look at them thoroughly.
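
For reference, the one-vs-rest computation under discussion amounts to binarizing the labels and drawing one binary reliability curve per class; a sketch with plain scikit-learn on synthetic data (illustration only, not the code in this PR):

# One-vs-rest reliability curves: one binary calibration curve per class,
# treating "class k" vs "rest". Whether this is appropriate for multiclass
# probability outputs is exactly the trade-off discussed above.
import numpy as np
from sklearn.calibration import calibration_curve
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize

X, y = make_classification(
    n_samples=2_000, n_classes=3, n_informative=6, random_state=0
)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression(max_iter=1_000).fit(X_train, y_train)

proba = clf.predict_proba(X_test)                     # shape (n_samples, 3)
y_bin = label_binarize(y_test, classes=clf.classes_)  # shape (n_samples, 3)

for k, cls in enumerate(clf.classes_):
    prob_true, prob_pred = calibration_curve(
        y_bin[:, k], proba[:, k], n_bins=10, strategy="uniform"
    )
    print(f"class {cls}: mean gap {np.abs(prob_true - prob_pred).mean():.3f}")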

@waridrox (Contributor, Author)

> Also, I have not gone through the code yet, but we need to think twice before implementing the multiclass case. Using a one-vs-rest approach might not be ideal, since some classifiers do not produce their probabilities with that strategy. I was looking at https://arxiv.org/pdf/2112.10327; there are several trade-offs and we should look at them thoroughly.

Sure, to be fair I was just building on the proposed implementation in https://github.com/probabl-ai/skore/pull/1315/files, but these architectural decisions are indeed crucial for the feature.

@glemaitre (Member)

Yep, no worries. I just wanted to put down some first thoughts so that we don't forget them. We could always restrict the scope to binary classification and regression, since those are better defined, and look into the multiclass case once this is integrated.

@thomass-dev (Collaborator) commented May 26, 2025

[automated comment] Please update your PR with main, so that the pytest workflow status will be reported.

